/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing how threads are mapped to a given tile.

*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/layout/pitch_linear.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace transform {

////////////////////////////////////////////////////////////////////////////////

/// Strip-mines a pitch-linear tile among a given number of threads, first along
/// the contiguous dimension then along the strided dimension.
///
/// The tile must be divisible by the thread count such that all threads may
/// execute the same number of iterations with the same delta to exhaustively
/// cover the tile.
///
/// This class satisfies the "RegularThreadMapping" concept.
///
/// This ThreadMap is used by SIMT kernels and operand E of the sparse tensor
/// kernels.
template <typename Shape_, int Threads, int ElementsPerAccess = 1>
struct PitchLinearStripminedThreadMap {
    /// Tensor coordinate
    using TensorCoord = layout::PitchLinearCoord;

    /// Tile shape
    using Shape = Shape_;

    /// Number of threads total
    static int const kThreads = Threads;

    /// Extract vector length from Layout
    static int const kElementsPerAccess = ElementsPerAccess;

    /// Shape of access by each thread
    using ThreadAccessShape = layout::PitchLinearShape<kElementsPerAccess, 1>;

    /// Internal implementation details
    struct Detail {
        static_assert(!(Shape::kContiguous % kElementsPerAccess), "");

        static_assert(!((Shape::kContiguous * Shape::kStrided) %
                        (kThreads * kElementsPerAccess)),
                      "Shape must be divisible thread count.");

        /// Shape of the tile in units of vectors
        using ShapeVec = layout::PitchLinearShape<
                Shape::kContiguous / kElementsPerAccess, Shape::kStrided>;

        static_assert((Threads < ShapeVec::kContiguous &&
                       !(ShapeVec::kContiguous % kThreads)) ||
                              (!(kThreads % ShapeVec::kContiguous) &&
                               !(ShapeVec::kStrided %
                                 (kThreads / ShapeVec::kContiguous))),
                      "Shape must be divisible by number of iterations of each "
                      "thread.");
    };

    /// Number of iterations by each thread
    using Iterations = typename platform::conditional<
            Threads >= Detail::ShapeVec::kContiguous,
            layout::PitchLinearShape<
                    1,
                    // Redo the comparison here to work around divide by zero
                    // compiler error.  The compiler evaluates both path of
                    // platform::conditional.
                    (Threads >= Detail::ShapeVec::kContiguous
                             ? Detail::ShapeVec::kStrided /
                                       (kThreads /
                                        Detail::ShapeVec::kContiguous)
                             : 0)>,
            layout::PitchLinearShape<Detail::ShapeVec::kContiguous / kThreads,
                                     Detail::ShapeVec::kStrided>>::type;

    /// Interval between accesses along each dimension of the tensor's logical
    /// coordinate space (in units of Elements)
    using Delta = typename platform::conditional<
            Threads >= Detail::ShapeVec::kContiguous,
            layout::PitchLinearShape<1,
                                     kThreads / Detail::ShapeVec::kContiguous>,
            layout::PitchLinearShape<kThreads * kElementsPerAccess, 1>>::type;

    /// Maps thread ID to a coordinate offset within the tensor's logical
    /// coordinate space (in units of Elements)
    CUTLASS_HOST_DEVICE
    static TensorCoord initial_offset(int thread_id) {
        return TensorCoord((thread_id % Detail::ShapeVec::kContiguous) *
                                   kElementsPerAccess,
                           thread_id / Detail::ShapeVec::kContiguous);
    }
};

/// This ThreadMap is used by GEMV
template <typename Shape, int Threads, int ElementsPerAccess = 1>
struct PitchLinearTilePolicyStripminedThreadContiguous {
    static_assert((Shape::kContiguous % (Threads * ElementsPerAccess)) == 0,
                  "Contiguous shape must divide number of threads");

    using TensorCoord = layout::PitchLinearCoord;

    static int const kThreads = Threads;
    static int const kElementsPerAccess = ElementsPerAccess;

    using Iterations =
            layout::PitchLinearShape<Shape::kContiguous /
                                             (kThreads * kElementsPerAccess),
                                     Shape::kStrided>;

    using Delta = layout::PitchLinearShape<1, 1>;

    CUTLASS_HOST_DEVICE
    static TensorCoord initial_offset(int thread_id) {
        return TensorCoord(
                thread_id * Iterations::kContiguous * kElementsPerAccess, 0);
    }
};

template <typename Shape, int Threads, int ElementsPerAccess = 1>
struct PitchLinearTilePolicyStripminedThreadStrided {
    static_assert((Shape::kStrided % Threads == 0),
                  "Strided shape must divide number of threads");

    using TensorCoord = layout::PitchLinearCoord;

    static int const kThreads = Threads;
    static int const kElementsPerAccess = ElementsPerAccess;

    using Iterations =
            layout::PitchLinearShape<Shape::kContiguous / kElementsPerAccess,
                                     Shape::kStrided / kThreads>;

    using Delta = layout::PitchLinearShape<1, 1>;

    using ShapeVec = Shape;

    CUTLASS_HOST_DEVICE
    static TensorCoord initial_offset(int thread_id) {
        return TensorCoord(0, thread_id * Iterations::kStrided);
    }
};

////////////////////////////////////////////////////////////////////////////////

/// This ThreadMap is used by GEMVBatchedReduction
template <typename Shape, typename ThreadArrangement, int ElementsPerAccess = 1>
struct PitchLinear2DTilePolicyStripminedThreadContiguous {
    static_assert((Shape::kContiguous %
                   (ThreadArrangement::kContiguous * ElementsPerAccess)) == 0 &&
                          (Shape::kStrided % ThreadArrangement::kStrided) == 0,
                  "Shape must divide thread shape");

    using TensorCoord = layout::PitchLinearCoord;

    static int const kThreads = ThreadArrangement::kCount;
    static int const kElementsPerAccess = ElementsPerAccess;

    using Iterations = layout::PitchLinearShape<
            Shape::kContiguous /
                    (ThreadArrangement::kContiguous * kElementsPerAccess),
            Shape::kStrided / ThreadArrangement::kStrided>;

    using Delta = layout::PitchLinearShape<1, 1>;

    CUTLASS_HOST_DEVICE
    static TensorCoord initial_offset(int thread_id) {
        return TensorCoord(thread_id * kElementsPerAccess,
                           threadIdx.y * Iterations::kStrided);
    }
};

////////////////////////////////////////////////////////////////////////////////

/// Policy defining a warp-raked arrangement in which a shape is partitioned
/// into contiguous elements.
///
/// This ThreadMap is used by tensor core kernels.
template <typename Shape_, int Threads, typename WarpThreadArrangement_,
          int ElementsPerAccess = 1>
struct PitchLinearWarpRakedThreadMap {
    /// Tensor coordinate
    using TensorCoord = layout::PitchLinearCoord;

    /// Tile shape
    using Shape = Shape_;

    /// Number of threads total
    static int const kThreads = Threads;

    /// Extract vector length from Layout
    static int const kElementsPerAccess = ElementsPerAccess;

    /// Shape of access by each thread
    using ThreadAccessShape = layout::PitchLinearShape<kElementsPerAccess, 1>;

    /// Internal details made public to facilitate introspection
    struct Detail {
        /// Fixed arrangement of threads within a warp (units of threads).
        using WarpThreadArrangement = WarpThreadArrangement_;

        /// Number of threads per warp
        static int const kWarpSize = WarpThreadArrangement::kCount;

        /// Number of participating warps
        static int const kWarpCount = kThreads / kWarpSize;

        static_assert(!(Shape::kContiguous % kElementsPerAccess),
                      "Shape must be divisible by vector length.");

        /// Compute the 'shape' of the overall tile in units of vectors
        using ShapeInAccesses = layout::PitchLinearShape<
                Shape::kContiguous / kElementsPerAccess, Shape::kStrided>;

        static_assert(
                !(ShapeInAccesses::kContiguous %
                  WarpThreadArrangement::kContiguous),
                "ShapeInAccesses must be divisible by WarpThreadArrangement.");

        static_assert(
                !(ShapeInAccesses::kStrided % WarpThreadArrangement::kStrided),
                "ShapeInAccesses must be divisible by WarpThreadArrangement.");

        // compute number of warp-level accesses total
        using WarpAccessIterations = layout::PitchLinearShape<
                ShapeInAccesses::kContiguous /
                        WarpThreadArrangement::kContiguous,
                ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided>;

        // Divide it into the number of warps, first partitioning the strided
        // dimension then the contiguous.
        static int const kWarpsStrided =
                (WarpAccessIterations::kStrided >= kWarpCount
                         ? kWarpCount
                         : WarpAccessIterations::kStrided);

        static int const kWarpsContiguous =
                (kWarpCount > WarpAccessIterations::kStrided
                         ? kWarpCount / kWarpsStrided
                         : 1);

        /// Arrangement of warps within a threadblock-scoped tile
        using WarpArrangement =
                layout::PitchLinearShape<kWarpsContiguous, kWarpsStrided>;
    };

    ///< Iterations along each dimension (concept: PitchLinearShape)
    using Iterations =
            layout::PitchLinearShape<Detail::WarpAccessIterations::kContiguous /
                                             Detail::kWarpsContiguous,
                                     Detail::WarpAccessIterations::kStrided /
                                             Detail::kWarpsStrided>;

    static_assert(Iterations::kCount, "Number of iterations must be non-zero");

    ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape)
    using Delta = layout::PitchLinearShape<
            Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess,
            Detail::WarpThreadArrangement::kStrided>;

    /// Maps thread ID to a coordinate offset within the tensor's logical
    /// coordinate space
    CUTLASS_HOST_DEVICE
    static TensorCoord initial_offset(int thread_id) {
        int warp_id = (thread_id / Detail::kWarpSize);
        int lane_id = (thread_id % Detail::kWarpSize);

        //
        // compute warp-level offset
        //

        // This is the shape of the entire area covered by a warp's memory
        // access (in units of vectors)
        layout::PitchLinearCoord warp_footprint{
                Detail::WarpThreadArrangement::kContiguous *
                        Iterations::kContiguous,
                Detail::WarpThreadArrangement::kStrided * Iterations::kStrided};

        // This is the offset of a specific warp (in units of vectors)
        layout::PitchLinearCoord warp_offset{
                (warp_id % Detail::kWarpsContiguous),
                (warp_id / Detail::kWarpsContiguous)};

        // This is the offset of a specific thread within a warp (units of
        // vectors)
        layout::PitchLinearCoord thread_offset_in_warp{
                lane_id % Detail::WarpThreadArrangement::kContiguous,
                lane_id / Detail::WarpThreadArrangement::kContiguous};

        // This is the offset of a thread within a threadblock tile (units of
        // vectors)
        layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec =
                warp_footprint * warp_offset + thread_offset_in_warp;

        // This is the offset of a thread within a threadblock tile (units of
        // elements)
        layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{
                thread_offset_in_threadblock_tile_vec.contiguous() *
                        kElementsPerAccess,
                thread_offset_in_threadblock_tile_vec.strided()};

        return thread_offset_in_threadblock_tile_base;
    }
};

////////////////////////////////////////////////////////////////////////////////

/// Policy defining a warp-raked arrangement in which a shape is partitioned
/// into contiguous elements. Warps are arranged based on a stride.
///
/// This ThreadMap is used by tensor core kernels for NCxHWx layout.
template <typename Shape_, int Threads, typename WarpThreadArrangement_,
          int ElementsPerAccess = 1>
struct PitchLinearStridedWarpRakedThreadMap {
    /// Tensor coordinate
    using TensorCoord = layout::PitchLinearCoord;

    /// Tile shape
    using Shape = Shape_;

    /// Number of threads total
    static int const kThreads = Threads;

    using WarpThreadArrangement = WarpThreadArrangement_;

    /// Extract vector length from Layout
    static int const kElementsPerAccess = ElementsPerAccess;

    /// Base ThreadMap
    using BaseThreadMap = PitchLinearWarpRakedThreadMap<
            Shape, kThreads, WarpThreadArrangement, kElementsPerAccess>;

    /// Shape of access by each thread
    using ThreadAccessShape = typename BaseThreadMap::ThreadAccessShape;

    struct Detail {
        using WarpThreadArrangement = WarpThreadArrangement_;

        using WarpAccessIterations =
                typename BaseThreadMap::Detail::WarpAccessIterations;

        static int const kWarpSize = BaseThreadMap::Detail::kWarpSize;

        static int const kWarpCount = BaseThreadMap::Detail::kWarpCount;

        using ShapeInAccesses = typename BaseThreadMap::Detail::ShapeInAccesses;

        // Divide it into the number of warps, first partitioning the contiguous
        // dimension then the stride.
        static int const kWarpsContiguous =
                (WarpAccessIterations::kContiguous >= kWarpCount
                         ? kWarpCount
                         : WarpAccessIterations::kContiguous);

        static int const kWarpsStrided =
                (kWarpCount > WarpAccessIterations::kContiguous
                         ? kWarpCount / kWarpsContiguous
                         : 1);

        /// Arrangement of warps within a threadblock-scoped tile
        using WarpArrangement =
                layout::PitchLinearShape<kWarpsContiguous, kWarpsStrided>;
    };

    ///< Iterations along each dimension (concept: PitchLinearShape)
    using Iterations =
            layout::PitchLinearShape<Detail::WarpAccessIterations::kContiguous /
                                             Detail::kWarpsContiguous,
                                     Detail::WarpAccessIterations::kStrided /
                                             Detail::kWarpsStrided>;

    static_assert(Iterations::kCount, "Number of iterations must be non-zero");

    ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape)
    using Delta = typename BaseThreadMap::Delta;

    /// Maps thread ID to a coordinate offset within the tensor's logical
    /// coordinate space
    CUTLASS_HOST_DEVICE
    static TensorCoord initial_offset(int thread_id) {
        int warp_id = (thread_id / Detail::kWarpSize);
        int lane_id = (thread_id % Detail::kWarpSize);

        //
        // compute warp-level offset
        //

        // This is the shape of the entire area covered by a warp's memory
        // access (in units of vectors)
        layout::PitchLinearCoord warp_footprint{
                Detail::WarpThreadArrangement::kContiguous *
                        Iterations::kContiguous,
                Detail::WarpThreadArrangement::kStrided * Iterations::kStrided};

        // This is the offset of a specific warp (in units of vectors)
        layout::PitchLinearCoord warp_offset{
                (warp_id % Detail::kWarpsContiguous),
                (warp_id / Detail::kWarpsContiguous)};

        // This is the offset of a specific thread within a warp (units of
        // vectors)
        layout::PitchLinearCoord thread_offset_in_warp{
                lane_id % Detail::WarpThreadArrangement::kContiguous,
                lane_id / Detail::WarpThreadArrangement::kContiguous};

        // This is the offset of a thread within a threadblock tile (units of
        // vectors)
        layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec =
                warp_footprint * warp_offset + thread_offset_in_warp;

        // This is the offset of a thread within a threadblock tile (units of
        // elements)
        layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{
                thread_offset_in_threadblock_tile_vec.contiguous() *
                        kElementsPerAccess,
                thread_offset_in_threadblock_tile_vec.strided()};

        return thread_offset_in_threadblock_tile_base;
    }
};

////////////////////////////////////////////////////////////////////////////////

/// Transpose the existing ThreadMap.  For example, interleaved layout is like
/// congruous in the global memory and crosswise in the shared memory.  We need
/// to transpose the coordinates between two.

template <typename ThreadMap_, typename WarpThreadArrangement_>
struct TransposePitchLinearThreadMap {
    /// Underlying ThreadMap
    using ThreadMap = ThreadMap_;

    /// Tensor coordinate
    using TensorCoord = typename ThreadMap::TensorCoord;

    /// Tile shape
    using Shape = typename ThreadMap::Shape;

    /// Number of threads total
    static int const kThreads = ThreadMap::kThreads;

    /// Extract vector length from Layout
    static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

    /// Shape of access by each thread
    using ThreadAccessShape = layout::PitchLinearShape<kElementsPerAccess, 1>;

    /// Internal details made public to facilitate introspection
    struct Detail {
        /// Fixed arrangement of threads within a warp (units of threads).
        using WarpThreadArrangement = WarpThreadArrangement_;

        /// Number of threads per warp
        static int const kWarpSize = WarpThreadArrangement::kCount;

        /// Number of participating warps
        static int const kWarpCount = kThreads / kWarpSize;

        static_assert(!(Shape::kContiguous % kElementsPerAccess),
                      "Shape must be divisible by vector length.");

        /// Arrangement of warps within a threadblock-scoped tile
        using WarpArrangement =
                layout::PitchLinearShape<ThreadMap::Detail::kWarpsStrided,
                                         ThreadMap::Detail::kWarpsContiguous>;
    };

    ///< Iterations along each dimension (concept: PitchLinearShape)
    using Iterations =
            layout::PitchLinearShape<ThreadMap::Iterations::kStrided,
                                     ThreadMap::Iterations::kContiguous>;

    static_assert(Iterations::kCount, "Number of iterations must be non-zero");

    ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape)
    using Delta = layout::PitchLinearShape<
            Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess,
            Detail::WarpThreadArrangement::kStrided>;

    /// Maps thread ID to a coordinate offset within the tensor's logical
    /// coordinate space Note this is slightly different from the one of
    /// PitchLinearWarpRakedThreadMap.
    CUTLASS_HOST_DEVICE
    static TensorCoord initial_offset(int thread_id) {
        int warp_id = (thread_id / Detail::kWarpSize);
        int lane_id = (thread_id % Detail::kWarpSize);

        //
        // compute warp-level offset
        //

        // This is the shape of the entire area covered by a warp's memory
        // access (in units of vectors)
        layout::PitchLinearCoord warp_footprint{
                Detail::WarpThreadArrangement::kContiguous *
                        Iterations::kContiguous,
                Detail::WarpThreadArrangement::kStrided * Iterations::kStrided};

        // This is the offset of a specific warp (in units of vectors)
        // Note the order of / and %. Also the 2nd operand is kStrided.
        layout::PitchLinearCoord warp_offset{
                (warp_id / Detail::WarpArrangement::kStrided),
                (warp_id % Detail::WarpArrangement::kStrided)};

        // This is the offset of a specific thread within a warp (units of
        // vectors)
        layout::PitchLinearCoord thread_offset_in_warp{
                lane_id % Detail::WarpThreadArrangement::kContiguous,
                lane_id / Detail::WarpThreadArrangement::kContiguous};

        // This is the offset of a thread within a threadblock tile (units of
        // vectors)
        layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec =
                warp_footprint * warp_offset + thread_offset_in_warp;

        // This is the offset of a thread within a threadblock tile (units of
        // elements)
        layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{
                thread_offset_in_threadblock_tile_vec.contiguous() *
                        kElementsPerAccess,
                thread_offset_in_threadblock_tile_vec.strided()};

        return thread_offset_in_threadblock_tile_base;
    }
};

template <typename ThreadMap_>
struct TransposePitchLinearThreadMapSimt {
    /// Underlying ThreadMap
    using ThreadMap = ThreadMap_;

    /// Tensor coordinate
    using TensorCoord = typename ThreadMap::TensorCoord;

    /// Tile shape
    using Shape = typename ThreadMap::Shape;

    /// Number of threads total
    static int const kThreads = ThreadMap::kThreads;

    /// Extract vector length from Layout
    static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

    static_assert(kElementsPerAccess == 1,
                  "Simt transpose requires elements per access to be 1");
    ///< Iterations along each dimension (concept: PitchLinearShape)
    using Iterations =
            layout::PitchLinearShape<ThreadMap::Iterations::kStrided,
                                     ThreadMap::Iterations::kContiguous>;

    static_assert(Iterations::kCount, "Number of iterations must be non-zero");

    /// Shape of access by each thread
    using ThreadAccessShape = typename ThreadMap::ThreadAccessShape;

    ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape)
    using Delta = layout::PitchLinearShape<ThreadMap::Delta::kStrided,
                                           ThreadMap::Delta::kContiguous>;

    /// Maps thread ID to a coordinate offset within the tensor's logical
    /// coordinate space Note this is slightly different from the one of
    /// PitchLinearWarpRakedThreadMap.
    CUTLASS_HOST_DEVICE
    static TensorCoord initial_offset(int thread_id) {
        TensorCoord coord = ThreadMap::initial_offset(thread_id);

        return TensorCoord(coord.strided(), coord.contiguous());
    }
};

////////////////////////////////////////////////////////////////////////////////

/// Policy defining a warp-striped arrangement.  This partitions a tile into
/// vectorized memory accesses performed by each warp then distributes warps
/// across them. Warps are striped in the strided dimension and raked across the
/// contiguous dimension.
template <typename Shape_,  /// Overall shape to partition in units of elements
          int Threads,      /// Number of partiticipation threads
          typename WarpThreadArrangement_,  /// Describes the shape of one
                                            /// memory access per warp
          int ElementsPerAccess =
                  1  /// Number of elements accessed by each thread per memory
                     /// operation (i.e. vector size)
          >
struct PitchLinearWarpStripedThreadMap {
    /// Tensor coordinate
    using TensorCoord = layout::PitchLinearCoord;

    /// Tile shape
    using Shape = Shape_;

    /// Number of threads total
    static int const kThreads = Threads;

    /// Extract vector length from Layout
    static int const kElementsPerAccess = ElementsPerAccess;

    /// Shape of access by each thread
    using ThreadAccessShape = layout::PitchLinearShape<kElementsPerAccess, 1>;

    /// Internal details made public to facilitate introspection
    struct Detail {
        /// Fixed arrangement of threads within a warp (units of threads).
        using WarpThreadArrangement = WarpThreadArrangement_;

        /// Number of threads per warp
        static int const kWarpSize = WarpThreadArrangement::kCount;

        /// Number of participating warps
        static int const kWarpCount = kThreads / kWarpSize;

        static_assert(!(Shape::kContiguous % kElementsPerAccess),
                      "Shape must be divisible by vector length.");

        /// Compute the 'shape' of the overall tile in units of vectors
        using ShapeInAccesses = layout::PitchLinearShape<
                Shape::kContiguous / kElementsPerAccess, Shape::kStrided>;

        // compute number of warp-level accesses total
        using WarpAccessIterations = layout::PitchLinearShape<
                ShapeInAccesses::kContiguous /
                        WarpThreadArrangement::kContiguous,
                ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided>;

        // Divide it into the number of warps, first partitioning the strided
        // dimension then the contiguous.
        static int const kWarpsStrided =
                (WarpAccessIterations::kStrided >= kWarpCount
                         ? kWarpCount
                         : (kWarpCount / WarpAccessIterations::kStrided));

        static int const kWarpsContiguous =
                (kWarpCount > WarpAccessIterations::kStrided
                         ? WarpAccessIterations::kContiguous / kWarpsStrided
                         : 1);

        /// Arrangement of warps within a threadblock-scoped tile
        using WarpArrangement =
                layout::PitchLinearShape<kWarpsContiguous, kWarpsStrided>;
    };

    ///< Iterations along each dimension (concept: PitchLinearShape)
    using Iterations =
            layout::PitchLinearShape<Detail::WarpAccessIterations::kContiguous /
                                             Detail::kWarpsContiguous,
                                     Detail::WarpAccessIterations::kStrided /
                                             Detail::kWarpsStrided>;

    static_assert(Iterations::kCount, "Number of iterations must be non-zero");

    ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape)
    using Delta = layout::PitchLinearShape<
            Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess,
            Detail::WarpThreadArrangement::kStrided *
                    Detail::WarpArrangement::kStrided>;

    /// Maps thread ID to a coordinate offset within the tensor's logical
    /// coordinate space
    CUTLASS_HOST_DEVICE
    static TensorCoord initial_offset(int thread_id) {
        int warp_id = (thread_id / Detail::kWarpSize);
        int lane_id = (thread_id % Detail::kWarpSize);

        //
        // compute warp-level offset
        //

        // This is the shape of the entire area covered by a warp's memory
        // access (in units of vectors)
        layout::PitchLinearCoord warp_footprint{
                Detail::WarpThreadArrangement::kContiguous *
                        Iterations::kContiguous,
                Detail::WarpThreadArrangement::kStrided};

        // This is the offset of a specific warp (in units of vectors)
        layout::PitchLinearCoord warp_offset{
                (warp_id % Detail::kWarpsContiguous),
                (warp_id / Detail::kWarpsContiguous)};

        // This is the offset of a specific thread within a warp (units of
        // vectors)
        layout::PitchLinearCoord thread_offset_in_warp{
                lane_id % Detail::WarpThreadArrangement::kContiguous,
                lane_id / Detail::WarpThreadArrangement::kContiguous};

        // This is the offset of a thread within a threadblock tile (units of
        // vectors)
        layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec =
                warp_footprint * warp_offset + thread_offset_in_warp;

        // This is the offset of a thread within a threadblock tile (units of
        // elements)
        layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{
                thread_offset_in_threadblock_tile_vec.contiguous() *
                        kElementsPerAccess,
                thread_offset_in_threadblock_tile_vec.strided()};

        return thread_offset_in_threadblock_tile_base;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Strip-mines a pitch-linear tile among a given number of threads, first along
/// the contiguous dimension then along the strided dimension, while each thread
/// access a 2D thread-tile.
///
/// The tile must be divisible by the thread count such that all threads may
/// execute the same number of iterations with the same delta to exhaustively
/// cover the tile.
///
/// This class satisfies the "RegularThreadMapping" concept.
template <typename Shape_, int Threads, typename ThreadTileShape>
struct PitchLinear2DThreadTileStripminedThreadMap;

template <typename Shape_, int Threads>
struct PitchLinear2DThreadTileStripminedThreadMap<
        Shape_, Threads, cutlass::layout::PitchLinearShape<4, 4>> {
    /// Tensor coordinate
    using TensorCoord = layout::PitchLinearCoord;

    /// Tile shape
    using Shape = Shape_;

    /// Access Shape of each thread
    using ThreadAccessShape = cutlass::layout::PitchLinearShape<4, 4>;
    // using ThreadAccessShape = ThreadTileShape;

    /// Number of threads total
    static int const kThreads = Threads;

    /// Extract length of each access from Layout
    static int const kElementsPerAccess = ThreadAccessShape::kContiguous;

    static_assert(!(kElementsPerAccess % 4),
                  "kElementsPerAccess, needs to be multiple of 4 (32bits)");

    /// Internal implementation details
    struct Detail {
        static_assert(!(ThreadAccessShape::kContiguous % 4),
                      "ThreadAccessShape, needs to be multiple of 4");

        static_assert(!(Shape::kContiguous % ThreadAccessShape::kContiguous),
                      "");

        static_assert(
                !((Shape::kContiguous * Shape::kStrided) %
                  (kThreads * ThreadAccessShape::kCount)),
                "Shape must be divisible thread count * accesses per thread.");

        /// Shape of the tile in units of vectors
        using ShapeVec = layout::PitchLinearShape<
                Shape::kContiguous / ThreadAccessShape::kContiguous,
                Shape::kStrided / ThreadAccessShape::kStrided>;

        static_assert((Threads < ShapeVec::kContiguous &&
                       !(ShapeVec::kContiguous % kThreads)) ||
                              (!(kThreads % ShapeVec::kContiguous) &&
                               !(ShapeVec::kStrided %
                                 (kThreads / ShapeVec::kContiguous))),
                      "Shape must be divisible by number of iterations of each "
                      "thread.");
    };

    /// Number of iterations by each thread
    using Iterations = typename platform::conditional<
            Threads >= Detail::ShapeVec::kContiguous,
            layout::PitchLinearShape<
                    1,
                    // Redo the comparison here to work around divide by zero
                    // compiler error.  The compiler evaluates both path of
                    // platform::conditional.
                    (Threads >= Detail::ShapeVec::kContiguous
                             ? Detail::ShapeVec::kStrided /
                                       (kThreads /
                                        Detail::ShapeVec::kContiguous)
                             : 0)>,
            layout::PitchLinearShape<Detail::ShapeVec::kContiguous / kThreads,
                                     Detail::ShapeVec::kStrided>>::type;

    /// Interval between accesses along each dimension of the tensor's logical
    /// coordinate space (in units of Elements)
    using Delta = typename platform::conditional<
            Threads >= Detail::ShapeVec::kContiguous,
            layout::PitchLinearShape<Shape::kContiguous,
                                     kThreads * ThreadAccessShape::kStrided /
                                             Detail::ShapeVec::kContiguous>,
            layout::PitchLinearShape<kThreads * ThreadAccessShape::kContiguous,
                                     1>>::type;

    /// Maps thread ID to a coordinate offset within the tensor's logical
    /// coordinate space (in units of Elements)
    CUTLASS_HOST_DEVICE
    static TensorCoord initial_offset(int thread_id) {
        return TensorCoord((thread_id % Detail::ShapeVec::kContiguous) *
                                   ThreadAccessShape::kContiguous,
                           (thread_id / Detail::ShapeVec::kContiguous) *
                                   ThreadAccessShape::kStrided);
    }
};

/// Thread Mapping a 2D threadtiled mapping as a tranposed
/// Pitchlinear2DThreadTile mapping
template <typename ThreadMap_>
struct TransposePitchLinearThreadMap2DThreadTile {
    /// Underlying ThreadMap
    using ThreadMap = ThreadMap_;

    /// Tensor coordinate
    using TensorCoord = typename ThreadMap::TensorCoord;

    /// Tile shape
    using Shape = typename ThreadMap::Shape;

    /// Number of threads total
    static int const kThreads = ThreadMap::kThreads;

    /// Extract vector length from Layout
    static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

    static_assert(kElementsPerAccess > 1,
                  "Simt transpose requires elements per access to be 1");
    ///< Iterations along each dimension (concept: PitchLinearShape)
    using Iterations =
            layout::PitchLinearShape<ThreadMap::Iterations::kStrided,
                                     ThreadMap::Iterations::kContiguous>;

    static_assert(Iterations::kCount, "Number of iterations must be non-zero");

    /// Shape of access by each thread
    using ThreadAccessShape = typename ThreadMap::ThreadAccessShape;

    ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape)
    using Delta = layout::PitchLinearShape<ThreadMap::Delta::kStrided,
                                           ThreadMap::Delta::kContiguous>;

    /// Maps thread ID to a coordinate offset within the tensor's logical
    /// coordinate space Note this is slightly different from the one of
    /// PitchLinearWarpRakedThreadMap.
    CUTLASS_HOST_DEVICE
    static TensorCoord initial_offset(int thread_id) {
        TensorCoord coord = ThreadMap::initial_offset(thread_id);
        return TensorCoord(coord.strided(), coord.contiguous());
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace transform
}  // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
